import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# 1. Generate the two source signals on one shared time axis.
torch.manual_seed(42)
num_samples = 1000

time_axis = torch.linspace(0, 8 * torch.pi, num_samples)
s1 = torch.sin(time_axis)  # sine wave
s2 = torch.sign(s1)  # square wave: sign of the same sine (0 where the sine is exactly 0)
S = torch.stack([s1, s2])  # (2, num_samples)

# 2. Mix the sources linearly: X = A @ S.
mixing_matrix = torch.tensor(
    [[1.0, 0.5],
     [0.5, 1.0]],
    dtype=torch.float32,
)
X = mixing_matrix @ S  # (2, num_samples)

# 3. Centre each mixed channel by subtracting its mean over time.
X_mean = X.mean(dim=1, keepdim=True)
X_centered = X - X_mean

# 4. ZCA whitening: with cov = E diag(lam) E^T, use W = E diag(lam)^(-1/2) E^T.
cov = (X_centered @ X_centered.T) / num_samples
eigvals, eigvecs = torch.linalg.eigh(cov)
# Floor the eigenvalues so 1/sqrt never sees zero, tiny or negative values.
eigvals = torch.clamp(eigvals, min=1e-5)
inv_sqrt_eigvals = 1.0 / torch.sqrt(eigvals)
whitening_matrix = eigvecs @ torch.diag(inv_sqrt_eigvals) @ eigvecs.T
X_white = whitening_matrix @ X_centered  # whitened mixtures (identity covariance up to the floor)

# 5. ICA model: a single learnable square unmixing matrix.
class ICA(nn.Module):
    """Linear unmixing layer computing ``Y = W @ X``.

    ``W`` is an ``(n_components, n_components)`` learnable parameter
    initialised to the identity, so an untrained model returns its input
    unchanged.
    """

    def __init__(self, n_components):
        super().__init__()
        identity = torch.eye(n_components)
        self.W = nn.Parameter(identity)

    def forward(self, X):
        """Apply the unmixing matrix to ``X``.

        ``X``'s leading (channel) dimension must equal ``n_components``; for a
        2-D ``X`` that means one row per channel.
        """
        return torch.matmul(self.W, X)

# 6. Build the model (2 sources -> 2x2 unmixing matrix) and its optimiser.
ica = ICA(n_components=2)
# ica.parameters() yields exactly one tensor, W.
optimizer = optim.Adam(ica.parameters(), lr=0.01)

# E[log cosh(nu)] for nu ~ N(0, 1): the value of the contrast for a Gaussian,
# which is the reference point of the negentropy approximation below.
GAUSSIAN_LOGCOSH_MEAN = 0.3745672075


def neg_entropy(y):
    """Approximate the negentropy of each row of ``y``.

    Uses Hyvarinen's approximation ``J(y) ~ (E[G(y)] - E[G(nu)])**2`` with
    ``G(u) = log cosh(u)`` and ``nu`` standard normal. ``J`` is ~0 for a
    Gaussian row and grows as the row departs from Gaussianity, in either the
    sub- or super-Gaussian direction. Maximising it therefore drives the
    outputs towards independent, non-Gaussian sources.

    (The previous ``mean(tanh(y))`` is an odd statistic. It is ~0 for any
    zero-mean symmetric signal, such as a sine or a square wave, so it gave
    the optimiser no separating signal.)

    The approximation assumes zero-mean, unit-variance rows.

    Args:
        y: 2-D tensor; each row is one estimated signal, columns are samples.

    Returns:
        1-D tensor with one non-negative negentropy estimate per row.
    """
    # log cosh(u) = logaddexp(u, -u) - log 2 stays finite where torch.cosh(u)
    # would overflow float32 (|u| greater than roughly 89).
    log_cosh = torch.logaddexp(y, -y) - math.log(2.0)
    return (log_cosh.mean(dim=1) - GAUSSIAN_LOGCOSH_MEAN) ** 2

# Train W by gradient ascent on the summed negentropy of the outputs.
num_epochs = 1000
for _epoch in range(num_epochs):
    optimizer.zero_grad()
    Y = ica(X_white)  # current estimate of the sources
    # Negate: the optimiser minimises, but we want to maximise negentropy.
    loss = -torch.sum(neg_entropy(Y))
    loss.backward()
    optimizer.step()

    # 7. Project W back onto the orthogonal matrices (symmetric
    # decorrelation, as in FastICA). For whitened input an orthogonal W keeps
    # the outputs uncorrelated with unit variance.
    #
    # The polar factor U @ Vh is the nearest orthogonal matrix to W. The Q
    # factor of torch.linalg.qr is not: R's diagonal may be negative, and for
    # a near-orthogonal W that gives Q ~= W @ diag(+-1). Such column sign
    # flips change the unmixing directions, not just the output signs.
    # Names are chosen so the module-level source tensor `S` is not clobbered.
    with torch.no_grad():
        u_mat, _, vh_mat = torch.linalg.svd(ica.W, full_matrices=False)
        ica.W.copy_(u_mat @ vh_mat)

# 8. Recover the sources with the trained unmixing matrix.
with torch.no_grad():
    separated = ica(X_white).cpu().numpy()  # NumPy array for matplotlib

# 9. Plot sources, mixtures and recovered signals as three stacked panels.
panels = (
    (S.T.detach().cpu().numpy(), "Original Source Signals"),
    (X.T.detach().cpu().numpy(), "Mixed Signals"),
    (separated.T, "Recovered Signals (ICA)"),
)

plt.figure(figsize=(10, 5))
for row, (series, heading) in enumerate(panels, start=1):
    plt.subplot(3, 1, row)
    plt.plot(series)  # each column (one signal) becomes its own line
    plt.title(heading)

plt.tight_layout()
plt.show()
